Support llama3 autoparallel + pipelining #1657
base: autoparallel
Conversation
So far just tested locally: `LOG_RANK=4 CONFIG_FILE=././torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name llama3_auto_parallel --parallelism.pipeline_parallel_degree 2 --training.steps 100`. Runs and loss converges. Left one TODO about global batch size and gradient accumulation.
```python
pp_degree = job_config.parallelism.pipeline_parallel_degree
```
Unused pp_degree config; should probably raise an error when it's not the local world size.
I deleted it (it was unused/unneeded). I don't think we need to raise any error: pp_degree does not need to equal any particular size, and pp can even be disabled.
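For reference, the kind of guard being suggested would look roughly like this (hypothetical sketch, not what the PR does; it assumes torchrun's `LOCAL_WORLD_SIZE` environment variable and the `job_config` object from the snippet above):

```python
import os

# Hypothetical validation sketch: reject configs where the pipeline degree
# disagrees with the number of ranks on this host. The PR instead just
# deletes the unused variable, since pp_degree need not match any size.
pp_degree = job_config.parallelism.pipeline_parallel_degree
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
if pp_degree not in (1, local_world_size):
    raise ValueError(
        f"pipeline_parallel_degree={pp_degree} does not match "
        f"LOCAL_WORLD_SIZE={local_world_size}"
    )
```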
```python
spmd_dims.append("tp")
spmd_mesh = world_mesh[spmd_dims]

dp_degree = 1
```
Same here: the config could specify dp_degree instead of hardcoding 1.
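A minimal sketch of what that could look like, reusing the `parallel_dims` fields that appear in the commented-out code later in this diff (hypothetical wiring, not the PR's code):

```python
# Hypothetical: derive dp_degree from the parallelism setup instead of
# hardcoding 1, following the dp_replicate * dp_shard product used elsewhere.
dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
```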
```diff
-                    inputs, target=targets, losses=losses, input_batch=inputs
+                    # TODO: input_batch kwarg only needed for CP, but
+                    # autoparallel doesn't accept kwargs in its forward
+                    inputs, target=targets, losses=losses  #, input_batch=inputs
```
Curious, why does CP need `input_batch`?
I assumed you would know. Am I wrong?
```python
pp_degree = job_config.parallelism.pipeline_parallel_degree
local_batch_size = job_config.training.local_batch_size
spmd_batch_size = local_batch_size
```
Oops, this is a bug for the non-PP case. It should be local * dp_degree, and put in an 'else' branch.
fixed
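A sketch of the fix described above, assuming torchtitan's usual `parallel_dims` fields; the actual commit may differ in detail:

```python
# Sketch of the described fix: for the non-PP case, scale the local batch
# size by dp_degree (presumably because the autoparallel SPMD program spans
# the full data-parallel batch); with PP enabled, keep the per-stage
# local batch size the PR already used.
local_batch_size = job_config.training.local_batch_size
if parallel_dims.pp_enabled:
    spmd_batch_size = local_batch_size
else:
    dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
    spmd_batch_size = local_batch_size * dp_degree
```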
```diff
@@ -492,11 +492,13 @@ def forward_backward_step(
             )
             if self.pp_has_first_stage:
                 self.pp_schedule.step(
-                    inputs, target=targets, losses=losses, input_batch=inputs
+                    # TODO: input_batch kwarg only needed for CP, but
+                    # autoparallel doesn't accept kwargs in its forward
```
Can we just fix this LOL
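If it can't be fixed on the autoparallel side, one possible workaround is a thin wrapper that accepts the CP-only kwarg and drops it before calling the compiled module (hypothetical sketch; `KwargsAdapter` is not part of the PR):

```python
import torch.nn as nn


class KwargsAdapter(nn.Module):
    """Hypothetical shim: the pipeline schedule passes input_batch as a kwarg
    (only needed for context parallelism), but the autoparallel module's
    forward takes positional args only, so we swallow the kwarg here."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, *args, input_batch=None, **kwargs):
        # input_batch is intentionally ignored: the autoparallel graph does
        # not consume it, and CP is not exercised in this configuration.
        return self.inner(*args, **kwargs)
```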
```python
# # step.
# dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
# global_batch_size = job_config.training.local_batch_size * dp_degree
if parallel_dims.pp_enabled and pp_rank > 0:
```
What a mess. No action needed here, but it's definitely worth thinking about what the terminal UX state should be.
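For context, the commented-out lines above are the bookkeeping the description's TODO refers to. A hedged sketch of how global batch size and gradient accumulation could tie together (the `global_batch_size` config field and its non-positive-means-unset convention are assumptions, not facts from this PR):

```python
# Hypothetical sketch of the global-batch-size / gradient-accumulation TODO.
# Only data-parallel ranks multiply the batch; PP stages see the same batch.
dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
local_batch_size = job_config.training.local_batch_size
global_batch_size = getattr(job_config.training, "global_batch_size", -1)
if global_batch_size <= 0:
    # Assumed convention: non-positive means "derive it, no accumulation".
    global_batch_size = local_batch_size * dp_degree
assert global_batch_size % (local_batch_size * dp_degree) == 0
grad_accum_steps = global_batch_size // (local_batch_size * dp_degree)
```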